{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Item Response Ranking for NCDM\n",
    "\n",
    "This notebook will show you how to train and use the IRR-NCDM.\n",
    "Refer to [IRR doc](../../docs/IRR.md) for more details.\n",
    "First, we will show how to get the data (here we use a0910 as the dataset).\n",
    "Then we will show how to train a IRR-NCDM and perform the parameters persistence.\n",
    "At last, we will show how to load the parameters from the file and evaluate on the test dataset."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "reading records from ../../data/a0910/item.csv: 100%|██████████| 19529/19529 [00:00<00:00, 56803.95it/s]\n",
      "rating2triplet: 100%|██████████| 17051/17051 [00:15<00:00, 1073.69it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": "(<longling.lib.iterator.AsyncLoopIter at 0x19579f67af0>,\n <torch.utils.data.dataloader.DataLoader at 0x195779a4e50>,\n <torch.utils.data.dataloader.DataLoader at 0x1957b015a00>)"
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import logging\n",
    "from longling.lib.structure import AttrDict\n",
    "from longling import set_logging_info\n",
    "from EduCDM.IRR import pair_etl as etl, point_etl as vt_etl, extract_item\n",
    "\n",
    "set_logging_info()\n",
    "\n",
    "params = AttrDict(\n",
    "    batch_size=256,\n",
    "    n_neg=10,\n",
    "    n_imp=10,\n",
    "    logger=logging.getLogger(),\n",
    "    hyper_params={\"user_num\": 4164, \"knowledge_num\": 123}\n",
    ")\n",
    "item_knowledge = extract_item(\"../../data/a0910/item.csv\", params[\"hyper_params\"][\"knowledge_num\"], params)\n",
    "train_data, train_df = etl(\"../../data/a0910/train.csv\", item_knowledge, params)\n",
    "valid_data, _ = vt_etl(\"../../data/a0910/valid.csv\", item_knowledge, params)\n",
    "test_data, _ = vt_etl(\"../../data/a0910/test.csv\", item_knowledge, params)\n",
    "\n",
    "train_data, valid_data, test_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [
    {
     "data": {
      "text/plain": "        user_id  item_id  score\n0          1615    12977    1.0\n1           782    13124    0.0\n2          1084    16475    0.0\n3           593     8690    0.0\n4           127    14225    1.0\n...         ...      ...    ...\n186044     2280     6019    0.0\n186045      121        2    1.0\n186046      601     5425    1.0\n186047      573     2412    0.0\n186048       60     2969    1.0\n\n[186049 rows x 3 columns]",
      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>user_id</th>\n      <th>item_id</th>\n      <th>score</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>0</th>\n      <td>1615</td>\n      <td>12977</td>\n      <td>1.0</td>\n    </tr>\n    <tr>\n      <th>1</th>\n      <td>782</td>\n      <td>13124</td>\n      <td>0.0</td>\n    </tr>\n    <tr>\n      <th>2</th>\n      <td>1084</td>\n      <td>16475</td>\n      <td>0.0</td>\n    </tr>\n    <tr>\n      <th>3</th>\n      <td>593</td>\n      <td>8690</td>\n      <td>0.0</td>\n    </tr>\n    <tr>\n      <th>4</th>\n      <td>127</td>\n      <td>14225</td>\n      <td>1.0</td>\n    </tr>\n    <tr>\n      <th>...</th>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n    </tr>\n    <tr>\n      <th>186044</th>\n      <td>2280</td>\n      <td>6019</td>\n      <td>0.0</td>\n    </tr>\n    <tr>\n      <th>186045</th>\n      <td>121</td>\n      <td>2</td>\n      <td>1.0</td>\n    </tr>\n    <tr>\n      <th>186046</th>\n      <td>601</td>\n      <td>5425</td>\n      <td>1.0</td>\n    </tr>\n    <tr>\n      <th>186047</th>\n      <td>573</td>\n      <td>2412</td>\n      <td>0.0</td>\n    </tr>\n    <tr>\n      <th>186048</th>\n      <td>60</td>\n      <td>2969</td>\n      <td>1.0</td>\n    </tr>\n  </tbody>\n</table>\n<p>186049 rows × 3 columns</p>\n</div>"
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_df"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 0: 727it [03:23,  3.57it/s]\n",
      "evaluating: 100%|██████████| 101/101 [00:00<00:00, 115.12it/s]\n",
      "formatting item df: 100%|██████████| 10415/10415 [00:00<00:00, 12423.18it/s]\n",
      "ranking metrics: 10415it [00:14, 735.24it/s]\n",
      "Epoch 1: 100%|██████████| 727/727 [02:49<00:00,  4.29it/s]\n",
      "evaluating: 100%|██████████| 101/101 [00:01<00:00, 91.61it/s]\n",
      "formatting item df: 100%|██████████| 10415/10415 [00:00<00:00, 11477.08it/s]\n",
      "ranking metrics: 10415it [00:14, 707.66it/s]\n",
      "INFO:root:save parameters to IRR-NCDM.params\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 0] Loss: 2.558156, PointLoss: 0.647514, PairLoss: 4.468798\n",
      "[Epoch 0]\n",
      "      ndcg@k  precision@k  recall@k      f1@k     len@k  support@k\n",
      "1   1.000000     0.704849  0.493072  0.547682  1.000000      10415\n",
      "3   0.895209     0.681741  0.743652  0.691947  1.906961      10415\n",
      "5   0.894566     0.676585  0.796747  0.713769  2.229573      10415\n",
      "10  0.893579     0.674508  0.816654  0.720577  2.423428      10415\n",
      "auc: 0.866700\tmap: 0.855576\tmrr: 0.924757\tcoverage_error: 3.349563\tranking_loss: 0.564976\tlen: 2.458569\tsupport: 10415\n",
      "[Epoch 1] Loss: 2.555617, PointLoss: 0.644294, PairLoss: 4.466940\n",
      "[Epoch 1]\n",
      "      ndcg@k  precision@k  recall@k      f1@k     len@k  support@k\n",
      "1   1.000000     0.704849  0.493072  0.547682  1.000000      10415\n",
      "3   0.895237     0.681741  0.743652  0.691947  1.906961      10415\n",
      "5   0.894598     0.676585  0.796747  0.713769  2.229573      10415\n",
      "10  0.893612     0.674508  0.816654  0.720577  2.423428      10415\n",
      "auc: 0.866700\tmap: 0.855650\tmrr: 0.924757\tcoverage_error: 3.349060\tranking_loss: 0.564641\tlen: 2.458569\tsupport: 10415\n"
     ]
    }
   ],
   "source": [
    "from EduCDM.IRR import NCDM\n",
    "\n",
    "cdm = NCDM(\n",
    "    4163 + 1,\n",
    "    17746 + 1,\n",
    "    123,\n",
    ")\n",
    "cdm.train(\n",
    "    train_data,\n",
    "    valid_data,\n",
    "    epoch=2,\n",
    ")\n",
    "cdm.save(\"IRR-NCDM.params\")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:load parameters from IRR-NCDM.params\n",
      "evaluating: 100%|██████████| 218/218 [00:01<00:00, 169.59it/s]\n",
      "formatting item df: 100%|██████████| 13682/13682 [00:01<00:00, 12955.64it/s]\n",
      "ranking metrics: 13682it [00:23, 571.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      ndcg@k  precision@k  recall@k      f1@k     len@k  support@k\n",
      "1   1.000000     0.703552  0.386927  0.454626  1.000000      13682\n",
      "3   0.871698     0.683635  0.676587  0.646876  2.268528      13682\n",
      "5   0.871883     0.674940  0.777347  0.696432  2.981582      13682\n",
      "10  0.869815     0.669786  0.847081  0.725684  3.723652      13682\n",
      "auc: 0.804861\tmap: 0.803387\tmrr: 0.895832\tcoverage_error: 5.059019\tranking_loss: 0.636236\tlen: 4.075428\tsupport: 13682\n"
     ]
    }
   ],
   "source": [
    "cdm.load(\"IRR-NCDM.params\")\n",
    "print(cdm.eval(test_data))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}